import visual_behavior_glm
import visual_behavior_glm.src.GLM_params as glm_params
import visual_behavior_glm.src.GLM_analysis_tools as gat
from visual_behavior_glm.src.glm import GLM
import matplotlib.pyplot as plt
import visual_behavior.data_access.loading as loading
import visual_behavior.database as db
import plotly.express as px
import pandas as pd
import numpy as np
import os
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
%matplotlib inline
import visual_behavior.data_access.loading as loading
# --- Data loading -------------------------------------------------------
# Pull the filtered ophys experiment table and inspect a hand-picked set of
# experiments; per the ID comments below, most are passive sessions that
# nonetheless contained licks.
experiments_table = loading.get_filtered_ophys_experiment_table()
# sessions_table = loading.get_filtered_ophys_session_table()
expts_with_licks = [984551228, 1010530054, 881949066, 980673831, 1007929142, 933338719]
# 459777_940852658_Camk2a-tTA_VISp_175_OPHYS_3_images_A_CAM2P3
# 480753_984551228_Sst-IRES-Cre_VISp_275_OPHYS_2_images_B_passive_CAM2P3
# 498972_1010530054_Sst-IRES-Cre_VISp_275_OPHYS_2_images_B_passive_CAM2P3
# 435431_881949066_Vip-IRES-Cre_VISp_225_OPHYS_2_images_A_passive_MESO1
# 477202_980673831_Sst-IRES-Cre_VISp_275_OPHYS_2_images_B_passive_CAM2P3
# 495789_1007929142_Sst-IRES-Cre_VISp_275_OPHYS_2_images_A_passive_CAM2P3
# 467951_933338719_Vip-IRES-Cre_VISp_175_OPHYS_2_images_B_passive_CAM2P3
experiments_table.loc[expts_with_licks][['ophys_session_id','cre_line','session_type','equipment_name','exposure_number']]
# Retrieve all GLM fit results just to list the available glm_versions.
results_all = gat.retrieve_results(results_type='full')
results_all['glm_version'].unique()
#use v4
# NOTE(review): the comment above says "use v4" but the search_dict below
# requests version '6_L2_optimize_by_cell' — confirm which is intended.
rs = gat.retrieve_results(search_dict = {'glm_version': '6_L2_optimize_by_cell'}, results_type='summary')
len(rs)
rs.keys()
# Build a unique per-cell key: "<ophys_experiment_id>_<cell_specimen_id>".
rs['identifier'] = rs['ophys_experiment_id'].astype(str) + '_' + rs['cell_specimen_id'].astype(str)
rs
# --- Reshape results ----------------------------------------------------
# Pivot to one row per cell identifier, one column per dropout, first for
# variance explained, then for the adjusted fraction-change metric.
model_output_type = 'variance_explained'
ve = rs.pivot(index='identifier',columns='dropout',values=model_output_type).reset_index()
ve
# Keep only cells whose full model explains >1% variance, ordered by
# full-model variance explained.
cells_to_include = ve[ve['Full']>0.01].identifier.values
order = np.argsort(ve[ve.identifier.isin(cells_to_include)==True]['Full'])
cell_order = cells_to_include[order]
len(cells_to_include)
rs.keys()
# model_output_type = 'fraction_change_from_full'
model_output_type = 'adj_fraction_change_from_full'
rsp = rs.pivot(index='identifier',columns='dropout',values=model_output_type).reset_index()
rsp
# NOTE(review): 'varience' is misspelled, but the name is referenced
# consistently throughout this script — renaming would require touching
# every downstream cell.
tmp = ve.rename(columns={'Full':'varience_explained_full_model'})
rsp = rsp.merge(tmp[['identifier','varience_explained_full_model']], on=['identifier'])
# rsp = rsp[rsp.identifier.isin(cells_to_include)==True]
rs.keys()
rs.keys()
# Attach experiment/cell metadata (cre line, session type, depth, etc.) to
# the per-cell dropout table.
rspm = rsp.merge(rs[['identifier','cell_specimen_id','ophys_experiment_id','cre_line','session_type','imaging_depth','equipment_name','project_code','session_number','exposure_number','container_id']].drop_duplicates(),left_on='identifier',right_on='identifier',how='inner')
rspm
rspm.keys()
rspm.session_type.unique()
# Commented-out helper that extracted a one-character session id from
# session_type, plus an optional CSV export — kept for reference.
# def map_session_types(session_type):
#     session_id = session_type[6:7]
#     return session_id
# rspm['session_id'] = rspm['session_type'].map(lambda st:map_session_types(st))
# rspm['session_id'].unique()
# save = False
# if save:
#     rspm.to_csv('/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/ophys_glm/fraction_change_var_explained_v_4_L2_fixed_lambda=1_2020.08.09.csv', index=False)
# First pass at the clustering columns: drop metadata columns, per-image
# regressors, and any 'single' dropouts.
# NOTE(review): this result is dead code — it is fully overwritten by the
# explicit list in the next cell.
cols_for_clustering = [col for col in rspm.columns if col not in ['identifier','cre_line','session_type','equipment_name',
                                                                  'session_id', 'imaging_depth','project_code','session_number','exposure_number']]
cols_for_clustering = [col for col in cols_for_clustering if col not in ['image0','image1','image2','image3',
                                                                         'image4','image5','image6','image7',
                                                                         'visual', 'Full']]
cols_for_clustering = [col for col in cols_for_clustering if 'single' not in col]
cols_for_clustering
len(cols_for_clustering)
# Explicit, final list of dropout columns used for clustering/PCA
# (overrides the filtered list built in the previous cell).
cols_for_clustering = [
    'all-images',
    'omissions',
    'pupil',
    'running',
    'rewards',
    'face_motion_energy',
    'image_expectation',
    'change',
    'hits',
    'misses',
    'correct_rejects',
    'false_alarms',
    'post_lick_bouts',
    'post_licks',
    'pre_lick_bouts',
    'pre_licks',
    'time',
    'model_bias',
    'model_omissions1',
    'model_task0',
    'model_timing1D',
]
len(cols_for_clustering)
rspm[cols_for_clustering]
# Feature matrix: one row per cell, one column per dropout score.
feature_matrix = rspm[cols_for_clustering]
# Heatmap of the raw feature matrix (cells x dropout scores).
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, center=0, cmap='RdBu_r', ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
# Same heatmap with rows sorted by full-model variance explained.
feature_matrix = rspm.sort_values('varience_explained_full_model').reset_index()[cols_for_clustering]
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, cmap='RdBu_r', center=0, ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
# Rows sorted by the omissions dropout.  NOTE(review): unlike the two plots
# above, this call omits center=0; with symmetric vmin/vmax the rendering
# should be equivalent — confirm intended.
feature_matrix = rspm.sort_values('omissions').reset_index()[cols_for_clustering]
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, cmap='RdBu_r', ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
feature_matrix.dropna()
# (commented) clustermap attempt — note sns.clustermap creates its own
# figure and does not take an ax kwarg, which may be why this was disabled
# feature_matrix = rspm.sort_values('omissions').reset_index()[cols_for_clustering]
# fig, ax = plt.subplots(figsize=(6,10))
# ax = sns.clustermap(feature_matrix.dropna(), vmin=-0.5, vmax=0.5, cmap='RdBu_r', ax=ax, cbar_kws={'label':model_output_type})
# ax.set_ylabel('cells')
# ax.set_title('clustered GLM feature matrix')
# --- Per-cell feature vectors across sessions ---------------------------
# For cells matched across at least 3 sessions, plot each cell's dropout
# feature vector per session as a small heatmap: one panel per cell, one
# figure per cre line, cells ordered by full-model variance explained.
df = rspm[rspm.project_code.isin(['VisualBehavior', 'VisualBehaviorTask1B'])]
cell_specimen_ids = df.cell_specimen_id.unique()
cre_lines = df.cre_line.unique()
# 'matched' = the cell appears in >=3 rows (sessions) of df
matched_cells = [cell_specimen_id for cell_specimen_id in cell_specimen_ids if len(df[df.cell_specimen_id==cell_specimen_id])>=3]
session_numbers = np.sort(df.session_number.unique())
for cre_line in cre_lines:
    nrows = 4
    ncols = 8
    fig, ax = plt.subplots(nrows,ncols, figsize=(20,20), sharey=True)
    ax = ax.ravel()
    cre_df = df[df.cell_specimen_id.isin(matched_cells)&(df.cre_line==cre_line)]
    # best-fit cells first, so the grid shows the top nrows*ncols cells
    cre_df = cre_df.sort_values(by='varience_explained_full_model', ascending=False)
    cre_matched_cells = cre_df.cell_specimen_id.unique()
    for i,cell_specimen_id in enumerate(cre_matched_cells[:(nrows*ncols)]):
        cdf = df[df.cell_specimen_id==cell_specimen_id]
        ax[i] = sns.heatmap(data=cdf.set_index('session_number')[cols_for_clustering].sort_values(by='session_number').T,
                            vmin=-0.5, vmax=0.5, cmap='RdBu_r', square=True, ax=ax[i], cbar=False)
        # ax[i].imshow(cdf[cols_for_clustering].values.T, vmin=-0.5, vmax=0.5, cmap='RdBu_r',)
        ax[i].set_xlim(-0.5,len(cdf)+0.5)
        ax[i].set_xlabel('')
        # ax[i].set_xticks(np.arange(0,len(session_numbers),1))
        # ax[i].set_xticklabels(session_numbers)
        ax[i].set_ylim(-0.5,len(cols_for_clustering)+0.5)
        ax[i].set_yticks(np.arange(0.5,len(cols_for_clustering)+0.5,1))
        ax[i].set_yticklabels(cols_for_clustering)
        # ax[i].set_yticklabels('')
        ax[i].set_title('csid: '+str(cell_specimen_id)+'\n '+cdf.session_type.values[0]+'\ndepth: '+str(int(cdf.imaging_depth.mean())), fontsize=10)
    # for y in np.arange(0,nrows*ncols,ncols):
    #     ax[y].set_xlabel('session #', fontsize=10)
    # for x in range(nrows):
    #     ax[(nrows*ncols-x)-1].set_yticks(np.arange(len(cols_for_clustering)))
    #     ax[(nrows*ncols-x)-1].set_yticklabels(cols_for_clustering, rotation=0)
    plt.suptitle(cre_line+' single cell feature vectors\ncolormap is '+model_output_type+' from -0.5 to 0.5', x=0.52, y=1.03, horizontalalignment='center')
    fig.tight_layout()
# Median dropout score per session number, one heatmap per cre line.
df = rspm[rspm.project_code.isin(['VisualBehavior', 'VisualBehaviorTask1B'])]
fig, ax = plt.subplots(1,3,figsize=(15,7))
for i,cre_line in enumerate(cre_lines):
    tmp = df[df.cre_line==cre_line]
    ax[i] = sns.heatmap(data=tmp.groupby(['session_number']).median()[cols_for_clustering].T,
                        ax=ax[i], square=True)
    ax[i].set_xlim(-0.5,len(tmp.session_number.unique())+0.5)
    ax[i].set_ylim(-0.5,len(cols_for_clustering)-0.5)
    ax[i].set_yticks(np.arange(0.5,len(cols_for_clustering)+0.5,1))
    ax[i].set_yticklabels(cols_for_clustering);
    ax[i].set_title(cre_line)
fig.tight_layout()
# Same per-session medians, broken out by imaging container: one 4x5 panel
# grid per cre line (assumes <=20 containers per cre line — the print below
# reports the actual count).
df = rspm[rspm.project_code.isin(['VisualBehavior', 'VisualBehaviorTask1B'])]
for c,cre_line in enumerate(cre_lines):
    tmp = df[(df.cre_line==cre_line)]
    containers = np.sort(tmp.container_id.unique())
    print(cre_line, len(containers))
    fig, ax = plt.subplots(4,5, figsize=(20,20))
    ax = ax.ravel()
    for i,container_id in enumerate(containers):
        cdf = tmp[(tmp.container_id==container_id)]
        ax[i] = sns.heatmap(data=cdf.groupby(['session_number']).median()[cols_for_clustering].T, ax=ax[i])
        ax[i].set_xlim(-0.5,len(cdf.session_number.unique())+0.5)
        ax[i].set_ylim(-0.5,len(cols_for_clustering)-0.5)
        ax[i].set_yticks(np.arange(0.5,len(cols_for_clustering)+0.5,1))
        ax[i].set_yticklabels(cols_for_clustering);
        ax[i].set_title(cre_line+' '+str(int(cdf.imaging_depth.mean()))+'\n'+str(container_id))
    fig.tight_layout()
# Near-duplicate of the earlier per-matched-cell figure; differs only in the
# axis limits (len(cdf)-0.5 / len(cols)-0.5) and tick placement (integer
# ticks instead of half-offset).  NOTE(review): consider keeping only one of
# the two versions.
session_numbers = np.sort(df.session_number.unique())
for cre_line in cre_lines:
    nrows = 4
    ncols = 8
    fig, ax = plt.subplots(nrows,ncols, figsize=(20,20), sharey=True)
    ax = ax.ravel()
    cre_df = df[df.cell_specimen_id.isin(matched_cells)&(df.cre_line==cre_line)]
    cre_df = cre_df.sort_values(by='varience_explained_full_model', ascending=False)
    cre_matched_cells = cre_df.cell_specimen_id.unique()
    for i,cell_specimen_id in enumerate(cre_matched_cells[:(nrows*ncols)]):
        cdf = df[df.cell_specimen_id==cell_specimen_id]
        ax[i] = sns.heatmap(data=cdf.set_index('session_number')[cols_for_clustering].sort_values(by='session_number').T,
                            vmin=-0.5, vmax=0.5, cmap='RdBu_r', square=True, ax=ax[i], cbar=False)
        # ax[i].imshow(cdf[cols_for_clustering].values.T, vmin=-0.5, vmax=0.5, cmap='RdBu_r',)
        ax[i].set_xlim(-0.5,len(cdf)-0.5)
        ax[i].set_xlabel('')
        # ax[i].set_xticks(np.arange(0,len(session_numbers),1))
        # ax[i].set_xticklabels(session_numbers)
        ax[i].set_ylim(-0.5,len(cols_for_clustering)-0.5)
        ax[i].set_yticks(np.arange(0,len(cols_for_clustering),1))
        ax[i].set_yticklabels(cols_for_clustering)
        # ax[i].set_yticklabels('')
        ax[i].set_title('csid: '+str(cell_specimen_id)+'\n '+cdf.session_type.values[0]+'\ndepth: '+str(int(cdf.imaging_depth.mean())), fontsize=10)
    # for y in np.arange(0,nrows*ncols,ncols):
    #     ax[y].set_xlabel('session #', fontsize=10)
    # for x in range(nrows):
    #     ax[(nrows*ncols-x)-1].set_yticks(np.arange(len(cols_for_clustering)))
    #     ax[(nrows*ncols-x)-1].set_yticklabels(cols_for_clustering, rotation=0)
    plt.suptitle(cre_line+' single cell feature vectors\ncolormap is '+model_output_type+' from -0.5 to 0.5', x=0.52, y=1.03, horizontalalignment='center')
    fig.tight_layout()
# Pairwise scatter plots of selected dropout scores, colored by cre line.
fig, ax = plt.subplots(figsize=(4,4))
ax = sns.scatterplot(data=rspm, x='omissions', y='all-images', hue='cre_line', ax=ax)
fig, ax = plt.subplots(figsize=(4,4))
ax = sns.scatterplot(data=rspm, x='omissions', y='pupil', hue='cre_line', ax=ax)
fig, ax = plt.subplots(figsize=(4,4))
ax = sns.scatterplot(data=rspm, x='pupil', y='running', hue='cre_line', ax=ax)
# One pointplot per metric: score vs session number, split by cre line.
colors = sns.color_palette()
colors = [colors[0], colors[2], colors[3]]
cre_lines = np.sort(rspm.cre_line.unique())
for metric in cols_for_clustering:
    fig, ax = plt.subplots(figsize=(6,4))
    sns.pointplot(data=rspm, x='session_number', y=metric, hue='cre_line', hue_order=cre_lines, palette=colors, ax=ax)
# Setup for the session-colored figures below.
cre_lines = np.sort(rspm.cre_line.unique())
session_numbers = np.sort(rspm.session_number.unique())
cre_line = cre_lines[0]
session_number = session_numbers[0]
def get_colors_for_session_numbers():
    """Return six RGB tuples: every other color of a reversed 6-step Reds
    palette, followed by the same selection from a reversed Blues palette.
    Used below to color one trace per session number."""
    palette_colors = []
    for palette_name in ('Reds_r', 'Blues_r'):
        palette_colors += sns.color_palette(palette_name, 6)[::2]
    return palette_colors
colors = get_colors_for_session_numbers()
# colors = [c[0], c[2], c[3], c[5]]
colors = get_colors_for_session_numbers()
# Per-cre-line pointplots of every dropout metric, one colored trace per
# session number.
# NOTE(review): plt.legend(labels=...) pairs labels with existing artists in
# draw order — verify the legend colors actually correspond to the session
# numbers, and that it attaches to the intended (last) axes.
fig, ax = plt.subplots(1,3, figsize=(16, 8))
for i,cre_line in enumerate(cre_lines):
    for c,session_number in enumerate(session_numbers):
        data = rspm[(rspm.cre_line==cre_line)&(rspm.session_number==session_number)][cols_for_clustering].melt()
        ax[i] = sns.pointplot(data=data, x='value', y='variable', ax=ax[i], color=colors[c], )
    ax[i].set_title(cre_line)
plt.legend(labels=session_numbers)
fig.tight_layout()
# Same figure, zoomed in on the x-axis.
colors = get_colors_for_session_numbers()
fig, ax = plt.subplots(1,3, figsize=(16, 8))
for i,cre_line in enumerate(cre_lines):
    for c,session_number in enumerate(session_numbers):
        data = rspm[(rspm.cre_line==cre_line)&(rspm.session_number==session_number)][cols_for_clustering].melt()
        ax[i] = sns.pointplot(data=data, x='value', y='variable', ax=ax[i], color=colors[c], )
    ax[i].set_title(cre_line)
    ax[i].set_xlim(-0.05, 0.01)
plt.legend(labels=session_numbers)
fig.tight_layout()
# Full pairwise grid of all clustering metrics (one panel per metric pair).
g = sns.PairGrid(rspm[cols_for_clustering+['cre_line']], hue="cre_line")
g.map_diag(plt.hist)
g.map_offdiag(plt.scatter)
g.add_legend();
# --- PCA on the dropout feature matrix ----------------------------------
# Drop rows with NaNs; .copy() ensures the pc1-3 column assignments below
# modify an independent frame rather than a view of rspm (avoids pandas
# SettingWithCopyWarning and silently-lost assignments).
data = rspm.dropna().copy()
n_features = len(cols_for_clustering)
n_components = len(cols_for_clustering)
pca = PCA(n_components=n_components)
# rows = cells, columns = dropouts; pca_result holds per-cell PC scores
pca_result = pca.fit_transform(data[cols_for_clustering].values)
# Shape sanity checks.
# bug fix: in the original notebook export these inspection lines appeared
# BEFORE `data`/`pca_result` were defined, which raises NameError when the
# script is executed top-to-bottom — moved after the fit.
pca_result.shape
data.shape
pca_result[:,0].shape
# Attach the first three PC scores per cell for the scatter plots below.
data['pc1'] = pca_result[:,0]
data['pc2'] = pca_result[:,1]
data['pc3'] = pca_result[:,2]
print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_))
# Cumulative variance: how many PCs are needed to reach 90% / 95%.
np.cumsum(pca.explained_variance_ratio_)
np.searchsorted(np.cumsum(pca.explained_variance_ratio_), .90)
np.searchsorted(np.cumsum(pca.explained_variance_ratio_), .95)
# Scree plot: variance explained per principal component.
fig,ax=plt.subplots()
ax.plot(
    np.arange(n_components),
    pca.explained_variance_ratio_,
    'o-k'
)
ax.set_xlabel('PC number')
ax.set_ylabel('variance explained')
# NOTE(review): the '10 PCs / >95%' claim is hard-coded — recompute from the
# searchsorted results above if the data or GLM version changes.
ax.set_title('first 10 PCs explain >95% of the variance')
# Heatmap of the principal axes (rows of pca.components_) in feature space.
fig, ax = plt.subplots(figsize=(6,6))
ax = sns.heatmap(pca.components_, vmin=-1, vmax=1, cmap='RdBu_r', ax=ax, square=True,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'weight'})
ax.set_ylabel('principal components')
ax.set_xlabel('features')
# ax.set_title('principal axes in feature space \n(directions of maximum variance in the data)')
ax.set_ylim(0, n_components)
ax.set_xticklabels(cols_for_clustering, rotation=90);
pca.components_.shape
# Line plot of the first 8 PCs' weights across features.
fig,ax=plt.subplots(figsize=(12,4))
N_PCs = 8
for PC in range(N_PCs):
    ax.plot(pca.components_[PC,:])
ax.legend(np.arange(N_PCs), title='PC', bbox_to_anchor=(1,1))
# NOTE(review): the shaded band uses 1/sqrt(32), but there are
# len(cols_for_clustering)=21 features — confirm the intended reference.
ax.axhspan(1/np.sqrt(32), -1/np.sqrt(32), zorder=-np.inf, alpha=0.25)
ax.set_xticks(np.arange(len(cols_for_clustering)))
ax.set_xticklabels(cols_for_clustering, rotation=45, ha='right')
ax.set_ylabel('weight')
ax.set_ylim(-1.1,1.1)
fig.tight_layout()
# Line plot of the remaining PCs' weights (PCs 8-19; the first 8 are above).
fig,ax=plt.subplots(figsize=(12,4))
first_pc, last_pc = 8, 20
for PC in range(first_pc, last_pc):
    ax.plot(pca.components_[PC,:])
# bug fix: the legend previously used np.arange(10,21) — 11 labels for the
# 12 traces plotted, and offset by 2 — so every trace was mislabeled; use
# the same range as the loop so labels match the plotted components.
ax.legend(np.arange(first_pc, last_pc), title='PC', bbox_to_anchor=(1,1))
# NOTE(review): 1/sqrt(32) does not match n_features (21) — confirm intent.
ax.axhspan(1/np.sqrt(32), -1/np.sqrt(32), zorder=-np.inf, alpha=0.25)
ax.set_xticks(np.arange(len(cols_for_clustering)))
ax.set_xticklabels(cols_for_clustering, rotation=45, ha='right')
ax.set_ylabel('weight')
ax.set_ylim(-1.1,1.1)
fig.tight_layout()
# Covariance matrix of features, as estimated by the fitted PCA model.
fig, ax = plt.subplots(figsize=(6,6))
ax = sns.heatmap(pca.get_covariance(), vmin=-0.002, vmax=0.002, cmap='RdBu_r', ax=ax, square=True,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'covariance'})
ax.set_title('covariance matrix')
ax.set_ylim(0, n_features)
ax.set_xticklabels(cols_for_clustering, rotation=90);
ax.set_yticklabels(cols_for_clustering, rotation=0);
# Cells x PC-score heatmap, unsorted.
# NOTE(review): the title says '% variance explained' but the values plotted
# are the PC scores (pca_result), not variance percentages.
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(pca_result, cmap='RdBu_r', ax=ax,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'activation'})
ax.set_ylabel('cells')
ax.set_xlabel('PC')
ax.set_title('% variance explained in PC space')
ax.set_ylim(0, pca_result.shape[0])
ax.set_xlim(0, pca_result.shape[1])
ax.set_xticks(np.arange(0, pca_result.shape[1]));
# ax.set_xticklabels(cols_for_clustering, rotation=90);
# Same heatmap with rows (cells) sorted by their PC1 score.
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(pca_result[np.argsort(pca_result[:,0])], cmap='RdBu_r', ax=ax,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'activation'})
ax.set_ylabel('cells')
ax.set_xlabel('PC')
ax.set_title('% variance explained in PC space')
ax.set_ylim(0, pca_result.shape[0])
ax.set_xlim(0, pca_result.shape[1])
ax.set_xticks(np.arange(0, pca_result.shape[1]));
# ax.set_xticklabels(cols_for_clustering, rotation=90);
# pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# 2D scatter of PC scores per cell, colored by cre line.
fig,ax = plt.subplots(1,2,figsize=(12,6))
ax[0] = sns.scatterplot(data=data, x="pc1", y="pc2", hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[0])
# ax[0].set_xlim(-5,10)
# ax[0].set_ylim(-5,10)
ax[1] = sns.scatterplot(data=data, x="pc2", y="pc3", hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[1])
# ax[1].set_xlim(-5,10)
# ax[1].set_ylim(-5,10)
# DataFrame of all PC scores indexed by cell identifier, plus cre line.
pca_result_df = pd.DataFrame(pca_result, index=data.identifier)
pca_result_df['cre_line'] = data['cre_line'].values
# pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# Scatter plots of consecutive PC pairs from pca_result_df.
# Column indices are zero-based: PC n lives at column n-1 (PC1=0, PC2=1).
PC1 = 0
PC2 = 1
PC3 = 2  # bug fix: was 3, which skipped PC3's column and plotted PC4 instead
PC4 = 3  # bug fix: was 4 (PC5's column) — follow the n-1 indexing of PC1/PC2
fig,ax = plt.subplots(1, 3, figsize=(15,5))
ax = ax.ravel()
i=0
ax[i] = sns.scatterplot(data=pca_result_df, x=PC1, y=PC2, hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
# ax[i].set_xlim(-100,100)
# ax[i].set_ylim(-100,100)
i+=1
ax[i] = sns.scatterplot(data=pca_result_df, x=PC2, y=PC3, hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
# ax[i].set_xlim(-100,100)
# ax[i].set_ylim(-100,100)
i+=1
ax[i] = sns.scatterplot(data=pca_result_df, x=PC3, y=PC4, hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
# ax[i].set_xlim(-100,100)
# ax[i].set_ylim(-100,100)
fig.tight_layout()
# Interactive 3D scatter of the first three PC scores (plotly); the query
# clips extreme outliers so the bulk of the points stays readable.
query_string = '''pc1>-100 and pc1<100 and pc2>-100 and pc2<100 and pc3>-100 and pc3<100'''
fig = px.scatter_3d(
    data.query(query_string),
    x='pc1',
    y='pc2',
    z='pc3',
    color='cre_line',
)
fig.update_traces(
    marker=dict(
        size=3,
        opacity=0.25
    )
)
fig.update_layout(
    margin=dict(l=30, r=30, t=10, b=10),
    width=1200,
    height=1000,
)
# fig.write_html("/home/dougo/code/dougollerenshaw.github.io/figures_to_share/2020.08.09_PCA_on_GLM.html")
# fig.show()
# (commented) exhaustive n_components x n_components PC-vs-PC scatter grid
# # pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# fig,ax = plt.subplots(n_components, n_components, figsize=(20,20))
# ax = ax.ravel()
# i = 0
# for PC1 in range(n_components):
#     for PC2 in range(n_components):
#         ax[i] = sns.scatterplot(data=pca_result_df, x=PC1, y=PC2, hue="cre_line",
#                                 palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
#         ax[i].set_xlim(-100,100)
#         ax[i].set_ylim(-100,100)
# # ax[1] = sns.scatterplot(data=rspm, x="pc2", y="pc3", hue="cre_line",
# #                         palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[1])
# # ax[1].set_xlim(-100,100)
# # ax[1].set_ylim(-100,100)
# --- K-means clustering --------------------------------------------------
# Cluster cells directly on the dropout feature matrix (NaN-free `data`).
# NOTE(review): KMeans has no random_state here, so cluster labels are not
# reproducible across runs — set one if stable labels are needed.
feature_matrix = data[cols_for_clustering]
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, cmap='RdBu_r', ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
kmeans = KMeans(n_clusters=10)
kmeans_result = kmeans.fit_predict(feature_matrix)
data['kmeans_result_features'] = kmeans_result
data['kmeans_result_features'].value_counts()
# Heatmap of the PC scores that feed the second clustering below.
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(pca_result, cmap='RdBu_r', ax=ax,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": '?'})
ax.set_ylabel('cells')
ax.set_xlabel('PC')
ax.set_title('')
ax.set_ylim(0, pca_result.shape[0])
ax.set_xlim(0, pca_result.shape[1])
ax.set_xticks(np.arange(0, pca_result.shape[1]));
# ax.set_xticklabels(cols_for_clustering, rotation=90);
# Cluster cells on their PC scores instead of the raw features.
kmeans = KMeans(n_clusters=10)
kmeans_result = kmeans.fit_predict(pca_result)
# bug fix: labels were assigned to rspm, but pca_result was computed from
# `data` (rspm with NaN rows dropped), so the label array is shorter than
# rspm whenever any row was dropped and the assignment raises ValueError.
# Assign to `data`, matching the kmeans-on-features cell above.
data['kmeans_result'] = kmeans_result
data['kmeans_result'].value_counts()
# NOTE(review): the line below is a leftover markdown cell from the notebook
# export ('"..." - Xiaoxuan paper' is not valid Python) — it will raise a
# SyntaxError if this file is run as a script; convert it to a comment.
"Then we applied consensus clustering to the PCs, by running K-means using the PCs 100 times until reaching a stable co-clustering association matrix, where each entry represents the probability of two units belonging to the same cluster." - Xiaoxuan paper
data.cre_line.unique()
cre_lines
data = rspm.dropna()
# Clustermap of the feature matrix with a row-color strip marking cre line.
# fig, ax = plt.subplots(figsize=(6,10))
tmp = data.copy()
# NOTE(review): this rebinds `cre_lines` (previously the sorted array of
# unique cre lines) to a per-row pandas Series — rename if any later cell
# relies on the original meaning.
cre_lines = tmp.pop('cre_line')
lut = dict(zip(cre_lines.unique(), "rbg"))
row_colors = cre_lines.map(lut)
ax = sns.clustermap(data[cols_for_clustering], cmap='RdBu_r', row_colors=row_colors)
# ax.set_ylabel('cells')
# ax.set_xlabel('PC')
# ax.set_title('')
# ax.set_ylim(0, pca_result.shape[0])
# ax.set_xlim(0, pca_result.shape[1])
# ax.set_xticks(np.arange(0, pca_result.shape[1]));
# ax.set_xticklabels(cols_for_clustering, rotation=90);